import os
import torch
from models.smart_model import SMART
from train.train import train_model, evaluate_model
from train.utils import load_data, print_metrics

# Root folder of the Malevis dataset, given relative to the working directory.
dataset_dir = './data/malevis'


# Fail fast with a clear message before any loading starts. isdir() is used
# rather than exists() so that a regular file sitting at this path is rejected
# here too instead of slipping through the check.
if not os.path.isdir(dataset_dir):
    raise FileNotFoundError(f"Dataset directory {dataset_dir} not found. Please make sure the Malevis dataset is in this directory.")

# Load the training/validation loaders and the class names from the dataset
# directory.
train_loader, val_loader, class_names = load_data(dataset_dir)

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

print("SMART model")
# The model gets one output class per entry in class_names.
num_classes = len(class_names)
model = SMART(num_classes=num_classes)
model = model.to(device)
print("model built")

# Loss function and optimizer used by the training loop.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop: one train_model call per epoch, reporting the loss it returns.
epochs = 10
print("epochs starting")

for epoch in range(epochs):
    loss = train_model(
        model,
        train_loader,
        optimizer,
        criterion,
        device,
    )
    # epoch is zero-based, so it is displayed shifted by one.
    progress_line = f"Epoch {epoch+1}/{epochs}, Loss: {loss:.4f}"
    print(progress_line)

# Evaluation on the validation loader, run once after the training loop ends.
evaluation = evaluate_model(model, val_loader, device)
all_preds, all_labels = evaluation

print("\nPerformance Metrics:")
# print_metrics is called with the labels first, then the predictions.
print_metrics(all_labels, all_preds, class_names)
